{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "from detector.anchor_generator import AnchorGenerator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate anchors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "\n",
    "anchor_generator = AnchorGenerator(\n",
    "    strides=[8, 16, 32, 64, 128],\n",
    "    scales=[32, 64, 128, 256, 512],\n",
    "    scale_multipliers=[1.0, 1.4142],\n",
    "    aspect_ratios=[1.0, 2.0, 0.5]\n",
    ")\n",
    "num_anchors_per_location = anchor_generator.num_anchors_per_location\n",
    "\n",
    "HEIGHT, WIDTH = 640, 640\n",
    "postprocessing_anchors = anchor_generator(HEIGHT, WIDTH)\n",
    "raw_anchors = anchor_generator.raw_anchors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with tf.Session() as sess:\n",
    "    anchors = sess.run(raw_anchors)\n",
    "    \n",
    "level = 7\n",
    "stride = 2**level\n",
    "\n",
    "anchors = anchors[level - 3]\n",
    "h = int(np.ceil(HEIGHT/stride))\n",
    "w = int(np.ceil(WIDTH/stride))\n",
    "\n",
    "anchors = anchors.reshape((h, w, num_anchors_per_location, 4))\n",
    "anchors.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Show non clipped anchors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ymin, xmin, ymax, xmax = [anchors[:, :, :, i] for i in range(4)]\n",
    "\n",
    "h, w = (ymax - ymin), (xmax - xmin)\n",
    "cy, cx = ymin + 0.5*h, xmin + 0.5*w\n",
    "\n",
    "centers = np.stack([cy, cx], axis=3)\n",
    "anchor_sizes = np.stack([h, w], axis=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, dpi=100, figsize=(int(5*WIDTH/HEIGHT), 5))\n",
    "unique_centers = centers[:, :, 0, :].reshape(-1, 2)\n",
    "unique_sizes = anchor_sizes[0, 0, :, :]\n",
    "\n",
    "i = 5\n",
    "for j, point in enumerate(unique_centers):\n",
    "    cy, cx = point\n",
    "    color = 'g' if j == i else 'r' \n",
    "    ax.plot([cx], [cy], marker='o', markersize=3, color=color)\n",
    "\n",
    "cy, cx = unique_centers[i] \n",
    "for box in unique_sizes:\n",
    "    h, w = box\n",
    "    xmin, ymin = cx - 0.5*w, cy - 0.5*h\n",
    "    rect = plt.Rectangle(\n",
    "        (xmin, ymin), w, h,\n",
    "        linewidth=1.0, edgecolor='k', facecolor='none'\n",
    "    )\n",
    "    ax.add_patch(rect)\n",
    "\n",
    "plt.xlim([0, WIDTH]);\n",
    "plt.ylim([0, HEIGHT]);"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
